// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//
#include <pollux/exec/local_planner.h>
#include <pollux/core/plan_fragment.h>
#include <pollux/exec/arrow_stream.h>
#include <pollux/exec/assign_uniqueId.h>
#include <pollux/exec/callback_sink.h>
#include <pollux/exec/enforce_single_row.h>
#include <pollux/exec/exchange.h>
#include <pollux/exec/expand.h>
#include <pollux/exec/filter_project.h>
#include <pollux/exec/group_id.h>
#include <pollux/exec/hash_aggregation.h>
#include <pollux/exec/hash_build.h>
#include <pollux/exec/hash_probe.h>
#include <pollux/exec/index_lookup_join.h>
#include <pollux/exec/limit.h>
#include <pollux/exec/mark_distinct.h>
#include <pollux/exec/merge.h>
#include <pollux/exec/merge_join.h>
#include <pollux/exec/nested_loop_join_build.h>
#include <pollux/exec/nested_loop_join_probe.h>
#include <pollux/exec/operator_trace_scan.h>
#include <pollux/exec/order_by.h>
#include <pollux/exec/partitioned_output.h>
#include <pollux/exec/round_robin_partition_function.h>
#include <pollux/exec/row_number.h>
#include <pollux/exec/scale_writer_local_partition.h>
#include <pollux/exec/streaming_aggregation.h>
#include <pollux/exec/table_scan.h>
#include <pollux/exec/table_write_merge.h>
#include <pollux/exec/table_writer.h>
#include <pollux/exec/task.h>
#include <pollux/exec/top_n.h>
#include <pollux/exec/top_n_row_number.h>
#include <pollux/exec/unnest.h>
#include <pollux/exec/values.h>
#include <pollux/exec/window.h>

namespace kumo::pollux::exec {
    namespace detail {
        /// Returns true if source nodes must run in a separate pipeline.
        bool mustStartNewPipeline(
            const std::shared_ptr<const core::PlanNode> &planNode,
            int sourceId) {
            if (auto localMerge =
                    std::dynamic_pointer_cast<const core::LocalMergeNode>(planNode)) {
                // LocalMerge's source runs on its own pipeline.
                return true;
            }

            if (std::dynamic_pointer_cast<const core::LocalPartitionNode>(planNode)) {
                return true;
            }

            // Non-first sources always run in their own pipeline.
            return sourceId != 0;
        }

        bool isIndexLookupJoin(core::PlanNodePtr planNode) {
            const auto indexLookupJoin =
                    std::dynamic_pointer_cast<const core::IndexLookupJoinNode>(planNode);
            return indexLookupJoin != nullptr;
        }

        // Creates the customized local partition operator for table writer scaling.
        std::unique_ptr<Operator> createScaleWriterLocalPartition(
            const std::shared_ptr<const core::LocalPartitionNode> &localPartitionNode,
            int32_t operatorId,
            DriverCtx *ctx) {
            if (dynamic_cast<const RoundRobinPartitionFunctionSpec *>(
                &localPartitionNode->partitionFunctionSpec())) {
                return std::make_unique<ScaleWriterLocalPartition>(
                    operatorId, ctx, localPartitionNode);
            }

            return std::make_unique<ScaleWriterPartitioningLocalPartition>(
                operatorId, ctx, localPartitionNode);
        }

        OperatorSupplier makeConsumerSupplier(ConsumerSupplier consumerSupplier) {
            if (consumerSupplier) {
                return [consumerSupplier = std::move(consumerSupplier)](
                    int32_t operatorId, DriverCtx *ctx) {
                    return std::make_unique<CallbackSink>(
                        operatorId, ctx, consumerSupplier());
                };
            }
            return nullptr;
        }

        OperatorSupplier makeConsumerSupplier(
            const std::shared_ptr<const core::PlanNode> &planNode) {
            if (auto localMerge =
                    std::dynamic_pointer_cast<const core::LocalMergeNode>(planNode)) {
                return [localMerge](int32_t operatorId, DriverCtx *ctx) {
                    auto mergeSource = ctx->task->addLocalMergeSource(
                        ctx->splitGroupId, localMerge->id(), localMerge->outputType());

                    auto consumer = [mergeSource](
                        RowVectorPtr input, ContinueFuture *future) {
                        return mergeSource->enqueue(std::move(input), future);
                    };
                    return std::make_unique<CallbackSink>(operatorId, ctx, consumer);
                };
            }

            if (auto localPartitionNode =
                    std::dynamic_pointer_cast<const core::LocalPartitionNode>(planNode)) {
                if (localPartitionNode->scaleWriter()) {
                    return [localPartitionNode](int32_t operatorId, DriverCtx *ctx) {
                        return createScaleWriterLocalPartition(
                            localPartitionNode, operatorId, ctx);
                    };
                }
                return [localPartitionNode](int32_t operatorId, DriverCtx *ctx) {
                    return std::make_unique<LocalPartition>(
                        operatorId, ctx, localPartitionNode);
                };
            }

            if (auto join =
                    std::dynamic_pointer_cast<const core::HashJoinNode>(planNode)) {
                return [join](int32_t operatorId, DriverCtx *ctx) {
                    return std::make_unique<HashBuild>(operatorId, ctx, join);
                };
            }

            if (auto join =
                    std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(planNode)) {
                return [join](int32_t operatorId, DriverCtx *ctx) {
                    return std::make_unique<NestedLoopJoinBuild>(operatorId, ctx, join);
                };
            }

            if (auto join =
                    std::dynamic_pointer_cast<const core::MergeJoinNode>(planNode)) {
                auto planNodeId = planNode->id();
                return [planNodeId](int32_t operatorId, DriverCtx *ctx) {
                    auto source =
                            ctx->task->getMergeJoinSource(ctx->splitGroupId, planNodeId);
                    auto consumer = [source](RowVectorPtr input, ContinueFuture *future) {
                        return source->enqueue(std::move(input), future);
                    };
                    return std::make_unique<CallbackSink>(operatorId, ctx, consumer);
                };
            }

            return Operator::operatorSupplierFromPlanNode(planNode);
        }

        void plan(
            const std::shared_ptr<const core::PlanNode> &planNode,
            std::vector<std::shared_ptr<const core::PlanNode> > *currentPlanNodes,
            const std::shared_ptr<const core::PlanNode> &consumerNode,
            OperatorSupplier consumerSupplier,
            std::vector<std::unique_ptr<DriverFactory> > *driverFactories) {
            if (!currentPlanNodes) {
                driverFactories->push_back(std::make_unique<DriverFactory>());
                currentPlanNodes = &driverFactories->back()->planNodes;
                driverFactories->back()->consumerSupplier = std::move(consumerSupplier);
                driverFactories->back()->consumerNode = consumerNode;
            }

            const auto &sources = planNode->sources();
            if (sources.empty()) {
                driverFactories->back()->inputDriver = true;
            } else {
                const auto numSourcesToPlan =
                        isIndexLookupJoin(planNode) ? 1 : sources.size();
                for (int32_t i = 0; i < numSourcesToPlan; ++i) {
                    plan(
                        sources[i],
                        mustStartNewPipeline(planNode, i) ? nullptr : currentPlanNodes,
                        planNode,
                        makeConsumerSupplier(planNode),
                        driverFactories);
                }
            }

            currentPlanNodes->push_back(planNode);
        }

        // Sometimes consumer limits the number of drivers its producer can run.
        uint32_t maxDriversForConsumer(
            const std::shared_ptr<const core::PlanNode> &node) {
            if (std::dynamic_pointer_cast<const core::MergeJoinNode>(node)) {
                // MergeJoinNode must run single-threaded.
                return 1;
            }
            return std::numeric_limits<uint32_t>::max();
        }

        uint32_t maxDrivers(
            const DriverFactory &driverFactory,
            const core::QueryConfig &queryConfig) {
            uint32_t count = maxDriversForConsumer(driverFactory.consumerNode);
            if (count == 1) {
                return count;
            }
            for (auto &node: driverFactory.planNodes) {
                if (auto topN = std::dynamic_pointer_cast<const core::TopNNode>(node)) {
                    if (!topN->isPartial()) {
                        // final topN must run single-threaded
                        return 1;
                    }
                } else if (
                    auto values = std::dynamic_pointer_cast<const core::ValuesNode>(node)) {
                    // values node must run single-threaded, unless in test context
                    if (!values->isParallelizable()) {
                        return 1;
                    }
                } else if (std::dynamic_pointer_cast<const core::ArrowStreamNode>(node)) {
                    // ArrowStream node must run single-threaded.
                    return 1;
                } else if (
                    auto limit = std::dynamic_pointer_cast<const core::LimitNode>(node)) {
                    // final limit must run single-threaded
                    if (!limit->isPartial()) {
                        return 1;
                    }
                } else if (
                    auto orderBy =
                            std::dynamic_pointer_cast<const core::OrderByNode>(node)) {
                    // final orderby must run single-threaded
                    if (!orderBy->isPartial()) {
                        return 1;
                    }
                } else if (
                    auto localExchange =
                            std::dynamic_pointer_cast<const core::LocalPartitionNode>(node)) {
                    // Local gather must run single-threaded.
                    switch (localExchange->type()) {
                        case core::LocalPartitionNode::Type::kGather:
                            return 1;
                        case core::LocalPartitionNode::Type::kRepartition:
                            count = std::min(queryConfig.maxLocalExchangePartitionCount(), count);
                            break;
                        default:
                            POLLUX_UNREACHABLE("Unexpected local exchange type");
                    }
                } else if (std::dynamic_pointer_cast<const core::LocalMergeNode>(node)) {
                    // Local merge must run single-threaded.
                    return 1;
                } else if (std::dynamic_pointer_cast<const core::MergeExchangeNode>(node)) {
                    // Merge exchange must run single-threaded.
                    return 1;
                } else if (std::dynamic_pointer_cast<const core::MergeJoinNode>(node)) {
                    // Merge join must run single-threaded.
                    return 1;
                } else if (
                    auto join = std::dynamic_pointer_cast<const core::HashJoinNode>(node)) {
                    // Right semi project doesn't support multi-threaded execution.
                    if (join->isRightSemiProjectJoin()) {
                        return 1;
                    }
                } else if (
                    auto tableWrite =
                            std::dynamic_pointer_cast<const core::TableWriteNode>(node)) {
                    const auto &connectorInsertHandle =
                            tableWrite->insertTableHandle()->connectorInsertTableHandle();
                    if (!connectorInsertHandle->supportsMultiThreading()) {
                        return 1;
                    } else {
                        if (tableWrite->hasPartitioningScheme()) {
                            return queryConfig.taskPartitionedWriterCount();
                        } else {
                            return queryConfig.taskWriterCount();
                        }
                    }
                } else {
                    auto result = Operator::maxDrivers(node);
                    if (result) {
                        POLLUX_CHECK_GT(
                            *result,
                            0,
                            "maxDrivers must be greater than 0. Plan node: {}",
                            node->toString());
                        if (*result == 1) {
                            return 1;
                        }
                        count = std::min(*result, count);
                    }
                }
            }
            return count;
        }
    } // namespace detail

    // static
    void LocalPlanner::plan(
        const core::PlanFragment &planFragment,
        ConsumerSupplier consumerSupplier,
        std::vector<std::unique_ptr<DriverFactory> > *driverFactories,
        const core::QueryConfig &queryConfig,
        uint32_t maxDrivers) {
        for (auto &adapter: DriverFactory::adapters) {
            if (adapter.inspect) {
                adapter.inspect(planFragment);
            }
        }
        detail::plan(
            planFragment.planNode,
            nullptr,
            nullptr,
            detail::makeConsumerSupplier(consumerSupplier),
            driverFactories);

        (*driverFactories)[0]->outputDriver = true;

        if (planFragment.isGroupedExecution()) {
            determineGroupedExecutionPipelines(planFragment, *driverFactories);
            markMixedJoinBridges(*driverFactories);
        }

        // Determine number of drivers for each pipeline.
        for (auto &factory: *driverFactories) {
            factory->maxDrivers = detail::maxDrivers(*factory, queryConfig);
            factory->numDrivers = std::min(factory->maxDrivers, maxDrivers);

            // Pipelines running grouped/bucketed execution would have separate groups
            // of drivers dealing with separate split groups (one driver can access
            // splits from only one designated split group), hence we will have total
            // number of drivers multiplied by the number of split groups.
            if (factory->groupedExecution) {
                factory->numTotalDrivers =
                        factory->numDrivers * planFragment.numSplitGroups;
            } else {
                factory->numTotalDrivers = factory->numDrivers;
            }
        }
    }

    // static
    void LocalPlanner::determineGroupedExecutionPipelines(
        const core::PlanFragment &planFragment,
        std::vector<std::unique_ptr<DriverFactory> > &driverFactories) {
        // We run backwards - from leaf pipelines to the root pipeline.
        for (auto it = driverFactories.rbegin(); it != driverFactories.rend(); ++it) {
            auto &factory = *it;

            // See if pipelines have leaf nodes that use grouped execution strategy.
            if (planFragment.leafNodeRunsGroupedExecution(factory->leafNodeId())) {
                factory->groupedExecution = true;
            }

            // If a pipeline's leaf node is Local Partition, which has all sources
            // belonging to pipelines that run Grouped Execution, then our pipeline
            // should run Grouped Execution as well.
            if (auto localPartitionNode =
                    std::dynamic_pointer_cast<const core::LocalPartitionNode>(
                        factory->planNodes.front())) {
                size_t numGroupedExecutionSources{0};
                for (const auto &sourceNode: localPartitionNode->sources()) {
                    for (auto &anotherFactory: driverFactories) {
                        if (sourceNode == anotherFactory->planNodes.back() and
                            anotherFactory->groupedExecution) {
                            ++numGroupedExecutionSources;
                            break;
                        }
                    }
                }
                if (numGroupedExecutionSources > 0 and
                    numGroupedExecutionSources == localPartitionNode->sources().size()) {
                    factory->groupedExecution = true;
                }
            }
        }
    }

    // static
    void LocalPlanner::markMixedJoinBridges(
        std::vector<std::unique_ptr<DriverFactory> > &driverFactories) {
        for (auto &factory: driverFactories) {
            // We are interested in grouped execution pipelines only.
            if (!factory->groupedExecution) {
                continue;
            }

            // See if we have any join nodes.
            for (const auto &planNode: factory->planNodes) {
                if (auto joinNode =
                        std::dynamic_pointer_cast<const core::HashJoinNode>(planNode)) {
                    // See if the build source (2nd) belongs to an ungrouped execution.
                    auto &buildSourceNode = planNode->sources()[1];
                    for (auto &factoryOther: driverFactories) {
                        if (!factoryOther->groupedExecution &&
                            buildSourceNode->id() == factoryOther->outputNodeId()) {
                            factoryOther->mixedExecutionModeHashJoinNodeIds.emplace(
                                planNode->id());
                            factory->mixedExecutionModeHashJoinNodeIds.emplace(planNode->id());
                            break;
                        }
                    }
                } else if (
                    auto joinNode =
                            std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(
                                planNode)) {
                    // See if the build source (2nd) belongs to an ungrouped execution.
                    auto &buildSourceNode = planNode->sources()[1];
                    for (auto &factoryOther: driverFactories) {
                        if (!factoryOther->groupedExecution &&
                            buildSourceNode->id() == factoryOther->outputNodeId()) {
                            factoryOther->mixedExecutionModeNestedLoopJoinNodeIds.emplace(
                                planNode->id());
                            factory->mixedExecutionModeNestedLoopJoinNodeIds.emplace(
                                planNode->id());
                            break;
                        }
                    }
                }
            }
        }
    }

    namespace {
        // If the upstream is partial limit, downstream is final limit and we want to
        // flush as soon as we can to reach the limit and do as little work as possible.
        bool eagerFlush(const core::PlanNode &node) {
            if (auto *limit = dynamic_cast<const core::LimitNode *>(&node)) {
                return limit->isPartial() && limit->offset() + limit->count() < 10'000;
            }
            if (node.sources().empty()) {
                return false;
            }
            // Follow the first source, which is driving the output.
            return eagerFlush(*node.sources()[0]);
        }
    } // namespace

    std::shared_ptr<Driver> DriverFactory::createDriver(
        std::unique_ptr<DriverCtx> ctx,
        std::shared_ptr<ExchangeClient> exchangeClient,
        std::function<int(int pipelineId)> numDrivers) {
        auto driver = std::shared_ptr<Driver>(new Driver());
        ctx->driver = driver.get();
        std::vector<std::unique_ptr<Operator> > operators;
        operators.reserve(planNodes.size());

        for (int32_t i = 0; i < planNodes.size(); i++) {
            // Id of the Operator being made. This is not the same as 'i'
            // because some PlanNodes may get fused.
            auto id = operators.size();
            auto planNode = planNodes[i];
            if (auto filterNode =
                    std::dynamic_pointer_cast<const core::FilterNode>(planNode)) {
                if (i < planNodes.size() - 1) {
                    auto next = planNodes[i + 1];
                    if (auto projectNode =
                            std::dynamic_pointer_cast<const core::ProjectNode>(next)) {
                        operators.push_back(std::make_unique<FilterProject>(
                            id, ctx.get(), filterNode, projectNode));
                        i++;
                        continue;
                    }
                }
                operators.push_back(
                    std::make_unique<FilterProject>(id, ctx.get(), filterNode, nullptr));
            } else if (
                auto projectNode =
                        std::dynamic_pointer_cast<const core::ProjectNode>(planNode)) {
                operators.push_back(
                    std::make_unique<FilterProject>(id, ctx.get(), nullptr, projectNode));
            } else if (
                auto valuesNode =
                        std::dynamic_pointer_cast<const core::ValuesNode>(planNode)) {
                operators.push_back(std::make_unique<Values>(id, ctx.get(), valuesNode));
            } else if (
                auto arrowStreamNode =
                        std::dynamic_pointer_cast<const core::ArrowStreamNode>(planNode)) {
                operators.push_back(
                    std::make_unique<ArrowStream>(id, ctx.get(), arrowStreamNode));
            } else if (
                auto tableScanNode =
                        std::dynamic_pointer_cast<const core::TableScanNode>(planNode)) {
                operators.push_back(
                    std::make_unique<TableScan>(id, ctx.get(), tableScanNode));
            } else if (
                auto tableWriteNode =
                        std::dynamic_pointer_cast<const core::TableWriteNode>(planNode)) {
                operators.push_back(
                    std::make_unique<TableWriter>(id, ctx.get(), tableWriteNode));
            } else if (
                auto tableWriteMergeNode =
                        std::dynamic_pointer_cast<const core::TableWriteMergeNode>(
                            planNode)) {
                operators.push_back(std::make_unique<TableWriteMerge>(
                    id, ctx.get(), tableWriteMergeNode));
            } else if (
                auto mergeExchangeNode =
                        std::dynamic_pointer_cast<const core::MergeExchangeNode>(
                            planNode)) {
                operators.push_back(
                    std::make_unique<MergeExchange>(i, ctx.get(), mergeExchangeNode));
            } else if (
                auto exchangeNode =
                        std::dynamic_pointer_cast<const core::ExchangeNode>(planNode)) {
                // NOTE: the exchange client can only be used by one operator in a driver.
                POLLUX_CHECK_NOT_NULL(exchangeClient);
                operators.push_back(std::make_unique<Exchange>(
                    id, ctx.get(), exchangeNode, std::move(exchangeClient)));
            } else if (
                auto partitionedOutputNode =
                        std::dynamic_pointer_cast<const core::PartitionedOutputNode>(
                            planNode)) {
                operators.push_back(std::make_unique<PartitionedOutput>(
                    id, ctx.get(), partitionedOutputNode, eagerFlush(*planNode)));
            } else if (
                auto joinNode =
                        std::dynamic_pointer_cast<const core::HashJoinNode>(planNode)) {
                operators.push_back(std::make_unique<HashProbe>(id, ctx.get(), joinNode));
            } else if (
                auto joinNode =
                        std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(
                            planNode)) {
                operators.push_back(
                    std::make_unique<NestedLoopJoinProbe>(id, ctx.get(), joinNode));
            } else if (
                auto joinNode =
                        std::dynamic_pointer_cast<const core::IndexLookupJoinNode>(
                            planNode)) {
                operators.push_back(
                    std::make_unique<IndexLookupJoin>(id, ctx.get(), joinNode));
            } else if (
                auto aggregationNode =
                        std::dynamic_pointer_cast<const core::AggregationNode>(planNode)) {
                if (aggregationNode->isPreGrouped()) {
                    operators.push_back(std::make_unique<StreamingAggregation>(
                        id, ctx.get(), aggregationNode));
                } else {
                    operators.push_back(
                        std::make_unique<HashAggregation>(id, ctx.get(), aggregationNode));
                }
            } else if (
                auto expandNode =
                        std::dynamic_pointer_cast<const core::ExpandNode>(planNode)) {
                operators.push_back(std::make_unique<Expand>(id, ctx.get(), expandNode));
            } else if (
                auto groupIdNode =
                        std::dynamic_pointer_cast<const core::GroupIdNode>(planNode)) {
                operators.push_back(
                    std::make_unique<GroupId>(id, ctx.get(), groupIdNode));
            } else if (
                auto topNNode =
                        std::dynamic_pointer_cast<const core::TopNNode>(planNode)) {
                operators.push_back(std::make_unique<TopN>(id, ctx.get(), topNNode));
            } else if (
                auto limitNode =
                        std::dynamic_pointer_cast<const core::LimitNode>(planNode)) {
                operators.push_back(std::make_unique<Limit>(id, ctx.get(), limitNode));
            } else if (
                auto orderByNode =
                        std::dynamic_pointer_cast<const core::OrderByNode>(planNode)) {
                operators.push_back(
                    std::make_unique<OrderBy>(id, ctx.get(), orderByNode));
            } else if (
                auto windowNode =
                        std::dynamic_pointer_cast<const core::WindowNode>(planNode)) {
                operators.push_back(std::make_unique<Window>(id, ctx.get(), windowNode));
            } else if (
                auto rowNumberNode =
                        std::dynamic_pointer_cast<const core::RowNumberNode>(planNode)) {
                operators.push_back(
                    std::make_unique<RowNumber>(id, ctx.get(), rowNumberNode));
            } else if (
                auto topNRowNumberNode =
                        std::dynamic_pointer_cast<const core::TopNRowNumberNode>(
                            planNode)) {
                operators.push_back(
                    std::make_unique<TopNRowNumber>(id, ctx.get(), topNRowNumberNode));
            } else if (
                auto markDistinctNode =
                        std::dynamic_pointer_cast<const core::MarkDistinctNode>(planNode)) {
                operators.push_back(
                    std::make_unique<MarkDistinct>(id, ctx.get(), markDistinctNode));
            } else if (
                auto localMerge =
                        std::dynamic_pointer_cast<const core::LocalMergeNode>(planNode)) {
                auto localMergeOp =
                        std::make_unique<LocalMerge>(id, ctx.get(), localMerge);
                operators.push_back(std::move(localMergeOp));
            } else if (
                auto mergeJoin =
                        std::dynamic_pointer_cast<const core::MergeJoinNode>(planNode)) {
                auto mergeJoinOp = std::make_unique<MergeJoin>(id, ctx.get(), mergeJoin);
                ctx->task->createMergeJoinSource(ctx->splitGroupId, mergeJoin->id());
                operators.push_back(std::move(mergeJoinOp));
            } else if (
                auto localPartitionNode =
                        std::dynamic_pointer_cast<const core::LocalPartitionNode>(
                            planNode)) {
                operators.push_back(std::make_unique<LocalExchange>(
                    id,
                    ctx.get(),
                    localPartitionNode->outputType(),
                    localPartitionNode->id(),
                    ctx->partitionId));
            } else if (
                auto unnest =
                        std::dynamic_pointer_cast<const core::UnnestNode>(planNode)) {
                operators.push_back(std::make_unique<Unnest>(id, ctx.get(), unnest));
            } else if (
                auto enforceSingleRow =
                        std::dynamic_pointer_cast<const core::EnforceSingleRowNode>(
                            planNode)) {
                operators.push_back(
                    std::make_unique<EnforceSingleRow>(id, ctx.get(), enforceSingleRow));
            } else if (
                auto assignUniqueIdNode =
                        std::dynamic_pointer_cast<const core::AssignUniqueIdNode>(
                            planNode)) {
                operators.push_back(std::make_unique<AssignUniqueId>(
                    id,
                    ctx.get(),
                    assignUniqueIdNode,
                    assignUniqueIdNode->taskUniqueId(),
                    assignUniqueIdNode->uniqueIdCounter()));
            } else if (
                const auto traceScanNode =
                        std::dynamic_pointer_cast<const core::TraceScanNode>(planNode)) {
                operators.push_back(std::make_unique<trace::OperatorTraceScan>(
                    id, ctx.get(), traceScanNode));
            } else {
                std::unique_ptr<Operator> extended;
                if (planNode->requiresExchangeClient()) {
                    // NOTE: the exchange client can only be used by one operator in a
                    // driver.
                    POLLUX_CHECK_NOT_NULL(exchangeClient);
                    extended = Operator::fromPlanNode(
                        ctx.get(), id, planNode, std::move(exchangeClient));
                } else {
                    extended = Operator::fromPlanNode(ctx.get(), id, planNode);
                }
                POLLUX_CHECK(extended, "Unsupported plan node: {}", planNode->toString());
                operators.push_back(std::move(extended));
            }
        }
        if (consumerSupplier) {
            operators.push_back(consumerSupplier(operators.size(), ctx.get()));
        }

        driver->init(std::move(ctx), std::move(operators));
        for (auto &adapter: adapters) {
            if (adapter.adapt(*this, *driver)) {
                break;
            }
        }
        driver->isAdaptable_ = false;
        return driver;
    }

    std::vector<std::unique_ptr<Operator> > DriverFactory::replaceOperators(
        Driver &driver,
        int32_t begin,
        int32_t end,
        std::vector<std::unique_ptr<Operator> > replaceWith) const {
        POLLUX_CHECK(driver.isAdaptable_);
        std::vector<std::unique_ptr<exec::Operator> > replaced;
        for (auto i = begin; i < end; ++i) {
            replaced.push_back(std::move(driver.operators_[i]));
        }

        driver.operators_.erase(
            driver.operators_.begin() + begin, driver.operators_.begin() + end);

        // Insert the replacement at the place of the erase. Do manually because
        // insert() is not good with unique pointers.
        driver.operators_.resize(driver.operators_.size() + replaceWith.size());
        for (int32_t i = driver.operators_.size() - 1;
             i >= begin + replaceWith.size();
             --i) {
            driver.operators_[i] = std::move(driver.operators_[i - replaceWith.size()]);
        }
        for (auto i = 0; i < replaceWith.size(); ++i) {
            driver.operators_[i + begin] = std::move(replaceWith[i]);
        }

        // Set the ids to be consecutive.
        for (auto i = 0; i < driver.operators_.size(); ++i) {
            driver.operators_[i]->setOperatorIdFromAdapter(i);
        }
        return replaced;
    }

    std::vector<core::PlanNodeId> DriverFactory::needsHashJoinBridges() const {
        std::vector<core::PlanNodeId> planNodeIds;
        // Ungrouped execution pipelines need to take care of cross-mode bridges.
        if (!groupedExecution && !mixedExecutionModeHashJoinNodeIds.empty()) {
            planNodeIds.insert(
                planNodeIds.end(),
                mixedExecutionModeHashJoinNodeIds.begin(),
                mixedExecutionModeHashJoinNodeIds.end());
        }
        for (const auto &planNode: planNodes) {
            if (auto joinNode =
                    std::dynamic_pointer_cast<const core::HashJoinNode>(planNode)) {
                // Grouped execution pipelines should not create cross-mode bridges.
                if (!groupedExecution ||
                    !mixedExecutionModeHashJoinNodeIds.contains(joinNode->id())) {
                    planNodeIds.emplace_back(joinNode->id());
                }
            }
        }
        return planNodeIds;
    }

    std::vector<core::PlanNodeId> DriverFactory::needsNestedLoopJoinBridges()
    const {
        std::vector<core::PlanNodeId> planNodeIds;
        // Ungrouped execution pipelines need to take care of cross-mode bridges.
        if (!groupedExecution && !mixedExecutionModeNestedLoopJoinNodeIds.empty()) {
            planNodeIds.insert(
                planNodeIds.end(),
                mixedExecutionModeNestedLoopJoinNodeIds.begin(),
                mixedExecutionModeNestedLoopJoinNodeIds.end());
        }
        for (const auto &planNode: planNodes) {
            if (auto joinNode =
                    std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(
                        planNode)) {
                // Grouped execution pipelines should not create cross-mode bridges.
                if (!groupedExecution ||
                    !mixedExecutionModeNestedLoopJoinNodeIds.contains(joinNode->id())) {
                    planNodeIds.emplace_back(joinNode->id());
                }
            }
        }

        return planNodeIds;
    }

    // static
    void DriverFactory::registerAdapter(DriverAdapter adapter) {
        adapters.push_back(std::move(adapter));
    }

    // static
    std::vector<DriverAdapter> DriverFactory::adapters;
} // namespace kumo::pollux::exec
